package dac

import (
	"database/sql/driver"
	"encoding/binary"
	"fmt"
	"sync/atomic"
	"syscall"
	"unsafe"
)

// Standard COM HRESULT status codes returned by the DAC library entry points.
const (
	S_OK    = 0x00000000
	S_FALSE = 0x00000001

	E_UNEXPECTED   = 0x8000FFFF
	E_NOTIMPL      = 0x80004001
	E_OUTOFMEMORY  = 0x8007000E
	E_INVALIDARG   = 0x80070057
	E_NOINTERFACE  = 0x80004002
	E_POINTER      = 0x80004003
	E_HANDLE       = 0x80070006
	E_ABORT        = 0x80004004
	E_FAIL         = 0x80004005
	E_ACCESSDENIED = 0x80070005
	E_PENDING      = 0x8000000A
)

// Flags accepted as the coInit parameter of CoInitializeEx.
const (
	COINIT_MULTITHREADED     = 0 // OLE calls objects on any thread.
	COINIT_APARTMENTTHREADED = 2 // Apartment model.
	COINIT_DISABLE_OLE1DDE   = 4 // Don't use DDE for OLE1 support.
	COINIT_SPEED_OVER_MEMORY = 8 // Trade memory for speed.
)

// DacDBErr is a named error type declared for DAC database errors.
// NOTE(review): nothing in this section constructs or returns it; confirm
// whether it is still used elsewhere.
type DacDBErr error

// DacLib wraps a dynamically loaded DAC library. The library is loaded lazily
// by checkLoadLib, which stores the resolved export addresses in the lib*
// fields; a zero address means the export is unavailable.
type DacLib struct {
	// Diagnostic counters reported by StatusString. The methods in this
	// file update them with atomic.AddInt32.
	newDac_n      int32
	freeDac_n     int32
	query_new_n   int32
	query_close_n int32
	begin_trans_n int32
	end_trans_n   int32

	// Handle of the loaded library; 0 until checkLoadLib fully succeeds.
	libHandle            syscall.Handle
	libDacCoInitialize   uintptr
	libDacCoInitializeR  uintptr // optional export; 0 when absent
	libDacCoInitializeEx uintptr // optional export; 0 when absent
	libDacCoUninitialize uintptr
	libNewDacConn        uintptr
	libFreeDacConn       uintptr
	libBeginTrans        uintptr
	libCommitTrans       uintptr
	libRollbackTrans     uintptr
	libGetDacConnLastErr uintptr

	// Query-level exports.
	libNewQuery       uintptr
	libExecQuery      uintptr
	libSetQueryText   uintptr
	libSetQueryParams uintptr
	libOpenQuery      uintptr
	libCloseQuery     uintptr

	libGetQueryColumns                    uintptr
	libGetQueryColumnTypeDatabaseTypeName uintptr

	libGetDacQueryLastErr   uintptr
	libGetQueryRecordValues uintptr
	libQueryNext            uintptr
	// Path passed to syscall.LoadLibrary.
	libFile                 string
}

// NewDacLib returns a DacLib bound to the library at libFile. The library is
// not loaded here; each method loads it on demand via checkLoadLib.
func NewDacLib(libFile string) *DacLib {
	return &DacLib{libFile: libFile}
}

func innerUTF8FromPtr(r1 uintptr) string {
	p := (*byte)(unsafe.Pointer(r1))
	data := make([]byte, 0, 512)
	for *p != 0 {
		data = append(data, *p)
		r1 += unsafe.Sizeof(byte(0))
		p = (*byte)(unsafe.Pointer(r1))
	}
	return string(data)
}

func readUInt32(ptr uintptr) uint32 {
	data := make([]byte, 4)
	data[0] = *(*byte)(unsafe.Pointer(ptr))
	data[1] = *(*byte)(unsafe.Pointer(ptr + 1))
	data[2] = *(*byte)(unsafe.Pointer(ptr + 2))
	data[3] = *(*byte)(unsafe.Pointer(ptr + 3))
	return binary.LittleEndian.Uint32(data)
}

func readNBuf(ptr uintptr, n int, buf []byte) {
	for i := 0; i < n; i++ {
		buf[i] = *(*byte)(unsafe.Pointer(ptr + uintptr(i)))
	}
}

// StatusString reports the diagnostic counters as "created/released" pairs
// for connections, queries, rows and transactions.
//
// The DacLib counters are updated with atomic.AddInt32 elsewhere, so they are
// read with atomic.LoadInt32 here to avoid a data race.
// NOTE(review): dbrows_new_cnt / dbrows_free_cnt are declared outside this
// section and are read plainly; if they are updated atomically, load them
// atomically too.
func (this *DacLib) StatusString() string {
	return fmt.Sprintf("conn:%d/%d, query:%d/%d, rows:%d/%d, trans:%d/%d",
		atomic.LoadInt32(&this.newDac_n), atomic.LoadInt32(&this.freeDac_n),
		atomic.LoadInt32(&this.query_new_n), atomic.LoadInt32(&this.query_close_n),
		dbrows_new_cnt, dbrows_free_cnt,
		atomic.LoadInt32(&this.begin_trans_n), atomic.LoadInt32(&this.end_trans_n))
}

// checkLoadLib loads the library named by this.libFile and resolves its
// exports. It returns nil immediately once the library has been loaded.
//
// All required exports must resolve. If any one fails, the library is freed
// and the error is returned with the export name attached. Addresses are
// collected into locals and copied into the receiver only after every required
// export has resolved. A partial failure therefore never leaves a lib* field
// pointing into a library that has already been freed. DacCoInitializeR and
// DacCoInitializeEx are optional; they are left at 0 when absent.
//
// NOTE(review): there is no synchronization here. Concurrent first calls may
// race on loading the library and on publishing the fields.
func (this *DacLib) checkLoadLib() error {
	if this.libHandle != 0 {
		return nil
	}
	h, err := syscall.LoadLibrary(this.libFile)
	if err != nil {
		return err
	}

	procs := []struct {
		name     string
		dst      *uintptr
		required bool
	}{
		{"NewDacConn", &this.libNewDacConn, true},
		{"FreeDacConn", &this.libFreeDacConn, true},
		{"DacCoInitialize", &this.libDacCoInitialize, true},
		{"DacCoInitializeR", &this.libDacCoInitializeR, false},
		{"DacCoInitializeEx", &this.libDacCoInitializeEx, false},
		{"DacCoUninitialize", &this.libDacCoUninitialize, true},
		{"BeginTrans", &this.libBeginTrans, true},
		{"CommitTrans", &this.libCommitTrans, true},
		{"RollbackTrans", &this.libRollbackTrans, true},
		{"GetDacConnLastErr", &this.libGetDacConnLastErr, true},
		{"NewQuery", &this.libNewQuery, true},
		{"SetQueryText", &this.libSetQueryText, true},
		{"SetQueryParams", &this.libSetQueryParams, true},
		{"ExecQuery", &this.libExecQuery, true},
		{"OpenQuery", &this.libOpenQuery, true},
		{"CloseQuery", &this.libCloseQuery, true},
		{"GetDacQueryLastErr", &this.libGetDacQueryLastErr, true},
		{"GetQueryColumns", &this.libGetQueryColumns, true},
		{"GetQueryColumnTypeDatabaseTypeName", &this.libGetQueryColumnTypeDatabaseTypeName, true},
		{"GetQueryRecordValues", &this.libGetQueryRecordValues, true},
		{"QueryNext", &this.libQueryNext, true},
	}

	// Resolve everything first. Nothing is written to the receiver until
	// every required export has been found.
	addrs := make([]uintptr, len(procs))
	for i, p := range procs {
		addr, err := syscall.GetProcAddress(h, p.name)
		if err != nil {
			if p.required {
				syscall.FreeLibrary(h)
				return fmt.Errorf("resolving %s in %s: %w", p.name, this.libFile, err)
			}
			addr = 0 // optional export missing
		}
		addrs[i] = addr
	}

	for i, p := range procs {
		*p.dst = addrs[i]
	}
	this.libHandle = h
	return nil
}

// NewDACConn opens a connection through the library's NewDacConn export.
// str is passed as a NUL-terminated byte buffer together with its length,
// including the terminator. It returns the handle produced by the export,
// or 0 if the export is unavailable. The connection counter is incremented
// whenever the export is called.
func (this *DacLib) NewDACConn(str string) uintptr {
	this.checkLoadLib()
	proc := this.libNewDacConn
	if proc == 0 {
		return 0
	}
	cstr := make([]byte, len(str)+1)
	copy(cstr, str)
	conn, _, _ := syscall.Syscall(proc, 2, uintptr(unsafe.Pointer(&cstr[0])), uintptr(len(cstr)), 0)
	atomic.AddInt32(&this.newDac_n, 1)
	return conn
}

// CoUninitialize calls the library's DacCoUninitialize export. It does
// nothing when that export is unavailable, for example because the library
// failed to load.
func (this *DacLib) CoUninitialize() {
	this.checkLoadLib()
	// Guard on the export that is actually invoked below.
	if this.libDacCoUninitialize == 0 {
		return
	}
	syscall.Syscall(this.libDacCoUninitialize, 0, 0, 0, 0)
}

// CoInitializeEx calls the library's optional DacCoInitializeEx export with
// flag (one of the COINIT_* values). It reports true only when the export
// returns S_OK. Any other result, including S_FALSE, and a missing export
// both yield false.
// NOTE(review): S_FALSE also counts as a successful initialization in COM;
// confirm that callers do not need to balance it with CoUninitialize.
func (this *DacLib) CoInitializeEx(flag int) bool {
	this.checkLoadLib()
	proc := this.libDacCoInitializeEx
	if proc != 0 {
		hr, _, _ := syscall.Syscall(proc, 1, uintptr(flag), 0, 0)
		return hr == S_OK
	}
	return false
}

// CoInitialize initializes COM through the library.
//
// If the DacCoInitializeEx export exists, it is called with
// COINIT_MULTITHREADED and the result is true only for S_OK. Otherwise the
// legacy DacCoInitialize export is called (when present). In that case the
// result is always false, to avoid a release; presumably this means callers
// skip the matching CoUninitialize.
func (this *DacLib) CoInitialize() bool {
	this.checkLoadLib()

	switch {
	case this.libDacCoInitializeEx != 0:
		hr, _, _ := syscall.Syscall(this.libDacCoInitializeEx, 1, uintptr(COINIT_MULTITHREADED), 0, 0)
		return hr == S_OK
	case this.libDacCoInitialize != 0:
		syscall.Syscall(this.libDacCoInitialize, 0, 0, 0, 0)
		return false // avoid a release
	default:
		return false
	}
}

// BeginTrans starts a transaction on conn via the library's BeginTrans
// export. A result of -1 (as int32) is turned into an error carrying the
// connection's last error message. The begin-transaction counter is
// incremented only on success.
func (this *DacLib) BeginTrans(conn uintptr) error {
	this.checkLoadLib()
	proc := this.libBeginTrans
	if proc == 0 {
		return fmt.Errorf("load libDac.dll exception!!!")
	}
	ret, _, _ := syscall.Syscall(proc, 1, conn, 0, 0)
	if int32(ret) != -1 {
		atomic.AddInt32(&this.begin_trans_n, 1)
		return nil
	}
	return fmt.Errorf("%s", this.GetConnLastErr(conn))
}

// RollbackTrans rolls back the transaction on conn via the library's
// RollbackTrans export. The end-of-transaction counter is incremented before
// the call, whether or not the call succeeds. A result of -1 (as int32) is
// turned into an error carrying the connection's last error message. When
// the library or export is unavailable, this returns nil.
func (this *DacLib) RollbackTrans(conn uintptr) error {
	this.checkLoadLib()

	unavailable := this.libHandle == 0 || this.libRollbackTrans == 0
	if unavailable {
		return nil
	}

	atomic.AddInt32(&this.end_trans_n, 1)

	ret, _, _ := syscall.Syscall(this.libRollbackTrans, 1, conn, 0, 0)
	if int32(ret) != -1 {
		return nil
	}
	return fmt.Errorf("%s", this.GetConnLastErr(conn))
}

// CommitTrans commits the transaction on conn via the library's CommitTrans
// export. The end-of-transaction counter is incremented before the call,
// whether or not the call succeeds. A result of -1 (as int32) is turned into
// an error carrying the connection's last error message.
//
// If the library or export is unavailable, an error is returned (matching
// BeginTrans). Returning nil would tell the caller the data was committed
// when nothing was sent.
func (this *DacLib) CommitTrans(conn uintptr) error {
	this.checkLoadLib()

	if this.libHandle == 0 || this.libCommitTrans == 0 {
		return fmt.Errorf("load libDac.dll exception!!!")
	}

	atomic.AddInt32(&this.end_trans_n, 1)

	r1, _, _ := syscall.Syscall(this.libCommitTrans, 1, conn, 0, 0)
	if int32(r1) == -1 {
		return fmt.Errorf("%s", this.GetConnLastErr(conn))
	}
	return nil
}

// GetConnLastErr fetches the last error message recorded for conn. The
// library's GetDacConnLastErr export fills a 2048-byte buffer and returns the
// number of bytes written. If the export is unavailable, a fixed
// "load libDac.dll exception!!!" message is returned.
func (this *DacLib) GetConnLastErr(conn uintptr) string {
	this.checkLoadLib()
	if this.libGetDacConnLastErr == 0 {
		return "load libDac.dll exception!!!"
	}
	strBuf := make([]byte, 2048)
	r1, _, _ := syscall.Syscall(this.libGetDacConnLastErr, 3, conn, uintptr(unsafe.Pointer(&strBuf[0])), uintptr(len(strBuf)))
	// Interpret the result as a 32-bit signed length, like the other exports
	// in this file. A non-positive value means there is no message. A value
	// larger than the buffer is clamped instead of panicking on the slice.
	n := int(int32(r1))
	if n <= 0 {
		return ""
	}
	if n > len(strBuf) {
		n = len(strBuf)
	}
	return string(strBuf[:n])
}

// NewQuery creates a query object on conn via the library's NewQuery export
// and returns its handle, or 0 when the export is unavailable. The query
// counter is incremented whenever the export is called.
func (this *DacLib) NewQuery(conn uintptr) uintptr {
	this.checkLoadLib()
	proc := this.libNewQuery
	if proc == 0 {
		return 0
	}
	qry, _, _ := syscall.Syscall(proc, 1, conn, 0, 0)
	atomic.AddInt32(&this.query_new_n, 1)
	return qry
}

// SetQueryText sets the SQL text of qry. txt is passed as a NUL-terminated
// byte buffer together with its length, including the terminator. The
// export's result is returned as a signed 32-bit value widened to int. If
// the export is unavailable, the result is -1.
func (this *DacLib) SetQueryText(qry uintptr, txt string) int {
	this.checkLoadLib()
	proc := this.libSetQueryText
	if proc == 0 {
		return -1
	}
	cstr := make([]byte, len(txt)+1)
	copy(cstr, txt)
	ret, _, _ := syscall.Syscall(proc, 3, qry, uintptr(unsafe.Pointer(&cstr[0])), uintptr(len(cstr)))
	return int(int32(ret))
}

// WriteParamsIntV appends an integer parameter to buf. The encoding is the
// type tag 2 followed by v as 64 bits via BytesAppendUInt64_LE. It returns
// the extended slice.
func (this *DacLib) WriteParamsIntV(buf []byte, v int64) []byte {
	const tagInt = 2
	return BytesAppendUInt64_LE(append(buf, tagInt), uint64(v))
}

// WriteParamsStringV appends a string parameter to buf. The encoding is the
// type tag 0, then the byte length of v via BytesAppendUInt32_LE, then the
// raw bytes of v. It returns the extended slice.
func (this *DacLib) WriteParamsStringV(buf []byte, v string) []byte {
	const tagString = 0
	buf = append(buf, tagString)
	buf = BytesAppendUInt32_LE(buf, uint32(len(v)))
	return append(buf, v...)
}

// SetQueryParams serializes params and hands them to the library's
// SetQueryParams export.
//
// The buffer starts with a single count byte. Each parameter follows,
// encoded as:
//   - any Go integer type: WriteParamsIntV (tag 2 + 64-bit value)
//   - string: its raw bytes via WriteParamsStringV
//   - []byte: its raw bytes via WriteParamsStringV
//   - anything else: its fmt "%v" text via WriteParamsStringV
//
// Because the count is one byte, more than 255 parameters is rejected with
// an error instead of being silently truncated.
func (this *DacLib) SetQueryParams(qry uintptr, params []driver.Value) error {
	this.checkLoadLib()
	if this.libSetQueryParams == 0 {
		return fmt.Errorf("load libDac.dll exception!!!")
	}
	if len(params) > 255 {
		return fmt.Errorf("too many query parameters: %d (max 255)", len(params))
	}
	buf := make([]byte, 0, 2048)
	buf = append(buf, byte(len(params)))
	for _, v := range params {
		switch v1 := v.(type) {
		case int:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case uint:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case int64:
			buf = this.WriteParamsIntV(buf, v1)
		case uint64:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case int8:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case uint8:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case int16:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case uint16:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case uint32:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case int32:
			buf = this.WriteParamsIntV(buf, int64(v1))
		case string:
			buf = this.WriteParamsStringV(buf, v1)
		case []byte:
			// "%v" would render a byte slice as "[97 98 ...]"; send the bytes.
			buf = this.WriteParamsStringV(buf, string(v1))
		default:
			// NOTE(review): nil, bool, float64 and time.Time go out as their
			// "%v" text (nil becomes "<nil>"); confirm the library expects this.
			buf = this.WriteParamsStringV(buf, fmt.Sprintf("%v", v1))
		}
	}
	syscall.Syscall(this.libSetQueryParams, 2, qry, uintptr(unsafe.Pointer(&buf[0])), 0)
	return nil
}

// GetQueryLastErr fetches the last error message recorded for qry. The
// library's GetDacQueryLastErr export fills a 2048-byte buffer and returns
// the number of bytes written. Like GetConnLastErr, it returns a fixed
// "load libDac.dll exception!!!" message when the export is unavailable,
// instead of calling address 0.
func (this *DacLib) GetQueryLastErr(qry uintptr) string {
	this.checkLoadLib()
	if this.libGetDacQueryLastErr == 0 {
		return "load libDac.dll exception!!!"
	}

	strBuf := make([]byte, 2048)
	r1, _, _ := syscall.Syscall(this.libGetDacQueryLastErr, 3, qry, uintptr(unsafe.Pointer(&strBuf[0])), uintptr(len(strBuf)))
	// Interpret the result as a 32-bit signed length. A non-positive value
	// means there is no message; an oversized value is clamped to the buffer.
	n := int(int32(r1))
	if n <= 0 {
		return ""
	}
	if n > len(strBuf) {
		n = len(strBuf)
	}
	return string(strBuf[:n])
}

// OpenQuery runs qry through the library's OpenQuery export and returns the
// export's result as an int. A result of -1 (as int32) becomes (0, error)
// with the query's last error message. A missing export yields (-1, error).
func (this *DacLib) OpenQuery(qry uintptr) (int, error) {
	this.checkLoadLib()
	proc := this.libOpenQuery
	if proc == 0 {
		return -1, fmt.Errorf("load libDac.dll exception!!!")
	}
	ret, _, _ := syscall.Syscall(proc, 1, qry, 0, 0)
	if n := int32(ret); n != -1 {
		return int(n), nil
	}
	return 0, fmt.Errorf("%s", this.GetQueryLastErr(qry))
}

// ExecQuery executes qry through the library's ExecQuery export and returns
// the export's result as an int. A result of -1 (as int32) becomes
// (0, error) with the query's last error message. A missing export yields
// (-1, error).
func (this *DacLib) ExecQuery(qry uintptr) (int, error) {
	this.checkLoadLib()
	proc := this.libExecQuery
	if proc == 0 {
		return -1, fmt.Errorf("load libDac.dll exception!!!")
	}
	ret, _, _ := syscall.Syscall(proc, 1, qry, 0, 0)
	if n := int32(ret); n != -1 {
		return int(n), nil
	}
	return 0, fmt.Errorf("%s", this.GetQueryLastErr(qry))
}

// GetQueryColumns returns the NUL-terminated string that the library's
// GetQueryColumns export produces for qry. It returns "" when the export is
// unavailable or returns a null pointer.
func (this *DacLib) GetQueryColumns(qry uintptr) string {
	this.checkLoadLib()
	if this.libGetQueryColumns == 0 {
		return ""
	}
	r1, _, _ := syscall.Syscall(this.libGetQueryColumns, 1, qry, 0, 0)
	// Compare the full pointer. Truncating to int32 would mistake a valid
	// 64-bit address whose low 32 bits are zero for NULL.
	if r1 == 0 {
		return ""
	}
	return innerUTF8FromPtr(r1)
}

// GetQueryColumnTypeDatabaseTypeName returns the NUL-terminated string that
// the library's GetQueryColumnTypeDatabaseTypeName export produces for
// column idx of qry. It returns "" when the export is unavailable or
// returns a null pointer.
func (this *DacLib) GetQueryColumnTypeDatabaseTypeName(qry uintptr, idx int) string {
	this.checkLoadLib()
	if this.libGetQueryColumnTypeDatabaseTypeName == 0 {
		return ""
	}
	r1, _, _ := syscall.Syscall(this.libGetQueryColumnTypeDatabaseTypeName, 2, qry, uintptr(idx), 0)
	// Compare the full pointer rather than its low 32 bits.
	if r1 == 0 {
		return ""
	}
	return innerUTF8FromPtr(r1)
}

// QueryNext calls the library's QueryNext export for qry (presumably
// advancing to the next record) and ignores its result. It does nothing
// when the export is unavailable, instead of calling address 0.
func (this *DacLib) QueryNext(qry uintptr) {
	this.checkLoadLib()
	if this.libQueryNext == 0 {
		return
	}
	syscall.Syscall(this.libQueryNext, 1, qry, 0, 0)
}

// CloseQuery releases qry via the library's CloseQuery export and bumps the
// closed-query counter. It is a no-op when the export is unavailable.
func (this *DacLib) CloseQuery(qry uintptr) {
	this.checkLoadLib()
	if proc := this.libCloseQuery; proc != 0 {
		syscall.Syscall(proc, 1, qry, 0, 0)
		atomic.AddInt32(&this.query_close_n, 1)
	}
}

// CloseConn releases conn via the library's FreeDacConn export and bumps the
// freed-connection counter. It is a no-op when the export is unavailable.
func (this *DacLib) CloseConn(conn uintptr) {
	this.checkLoadLib()
	if proc := this.libFreeDacConn; proc != 0 {
		syscall.Syscall(proc, 1, conn, 0, 0)
		atomic.AddInt32(&this.freeDac_n, 1)
	}
}

// GetQueryRecordValues decodes the current record of qry into dest[0:n].
//
// The library's GetQueryRecordValues export returns a pointer to a packed
// buffer of n consecutive fields. Each field starts with a one-byte tag:
//
//	255   -> decoded as nil, no payload
//	2     -> bool, one payload byte (1 = true, anything else = false)
//	0     -> string: uint32 little-endian length, then that many bytes
//	other -> []byte, same length-prefixed layout as strings
//
// If the export is unavailable or returns a null buffer, the first n entries
// of dest (bounded by len(dest)) are set to nil instead of dereferencing a
// null pointer. dest must otherwise hold at least n elements.
func (this *DacLib) GetQueryRecordValues(qry uintptr, n int, dest []driver.Value) {
	this.checkLoadLib()
	var p uintptr
	if this.libGetQueryRecordValues != 0 {
		p, _, _ = syscall.Syscall(this.libGetQueryRecordValues, 1, qry, 0, 0)
	}
	if p == 0 {
		for i := 0; i < n && i < len(dest); i++ {
			dest[i] = nil
		}
		return
	}
	for i := 0; i < n; i++ {
		tag := *(*byte)(unsafe.Pointer(p))
		p++
		switch tag {
		case 255:
			dest[i] = nil
		case 2: // boolean
			dest[i] = *(*byte)(unsafe.Pointer(p)) == 1
			p++
		default:
			size := readUInt32(p)
			p += 4
			buf := make([]byte, size)
			readNBuf(p, int(size), buf)
			p += uintptr(size)
			if tag == 0 {
				dest[i] = string(buf)
			} else {
				dest[i] = buf
			}
		}
	}
}
